7867d7
@@ -29,18 +29,13 @@
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.ByteObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.ShortObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.TimestampObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableHiveDecimalObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector;
 import org.apache.hadoop.io.Text;
 
 /**
@@ -58,6 +53,8 @@
     + "  \"Hello World 100 days\"")
 public class GenericUDFPrintf extends GenericUDF {
   private transient ObjectInspector[] argumentOIs;
+  protected transient Converter converterFormat;
+
   private final Text resultText = new Text();
 
   @Override
@@ -67,12 +64,25 @@
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
           "The function PRINTF(String format, Obj... args) needs at least one arguments.");
     }
 
-    if (arguments[0].getTypeName() != serdeConstants.STRING_TYPE_NAME
-      && arguments[0].getTypeName() != serdeConstants.VOID_TYPE_NAME) {
+    WritableStringObjectInspector resultOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector;
+
+    if (arguments[0].getCategory() == ObjectInspector.Category.PRIMITIVE) {
+      PrimitiveObjectInspector poi = ((PrimitiveObjectInspector) arguments[0]);
+      if (poi.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.STRING ||
+          poi.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.CHAR ||
+          poi.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.VARCHAR ||
+          poi.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.VOID) {
+        converterFormat = ObjectInspectorConverters.getConverter(arguments[0], resultOI);
+      } else {
         throw new UDFArgumentTypeException(0, "Argument 1"
-        + " of function PRINTF must be \"" + serdeConstants.STRING_TYPE_NAME
-        + "\", but \"" + arguments[0].getTypeName() + "\" was found.");
+            + " of function PRINTF must be \"" + serdeConstants.STRING_TYPE_NAME
+            + "\", but \"" + arguments[0].getTypeName() + "\" was found.");
       }
+    } else {
+      throw new UDFArgumentTypeException(0, "Argument 1"
+          + " of function PRINTF must be \"" + serdeConstants.STRING_TYPE_NAME
+          + "\", but \"" + arguments[0].getTypeName() + "\" was found.");
+    }
 
     for (int i = 1; i < arguments.length; i++) {
       if (!arguments[i].getCategory().equals(Category.PRIMITIVE)){
@@ -83,58 +93,46 @@
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
     }
 
     argumentOIs = arguments;
-    return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
+    return resultOI;
   }
 
   @Override
   public Object evaluate(DeferredObject[] arguments) throws HiveException {
+    // If the first argument is null, return null. (It's okay for other arguments to be null, in
+    // which case, "null" will be printed.)
+    if (arguments[0].get() == null) {
+      return null;
+    }
+
     StringBuilder sb = new StringBuilder();
     Formatter formatter = new Formatter(sb, Locale.US);
 
-    String pattern = ((StringObjectInspector) argumentOIs[0])
-        .getPrimitiveJavaObject(arguments[0].get());
+    Text pattern = (Text)converterFormat.convert(arguments[0].get());
 
-    ArrayList argumentList = new ArrayList();
+    ArrayList<Object> argumentList = new ArrayList<Object>();
     for (int i = 1; i < arguments.length; i++) {
       switch (((PrimitiveObjectInspector)argumentOIs[i]).getPrimitiveCategory()) {
         case BOOLEAN:
-          argumentList.add(((BooleanObjectInspector)argumentOIs[i]).get(arguments[i].get()));
-          break;
         case BYTE:
-          argumentList.add(((ByteObjectInspector)argumentOIs[i]).get(arguments[i].get()));
-          break;
         case SHORT:
-          argumentList.add(((ShortObjectInspector)argumentOIs[i]).get(arguments[i].get()));
-          break;
         case INT:
-          argumentList.add(((IntObjectInspector)argumentOIs[i]).get(arguments[i].get()));
-          break;
         case LONG:
-          argumentList.add(((LongObjectInspector)argumentOIs[i]).get(arguments[i].get()));
-          break;
         case FLOAT:
-          argumentList.add(((FloatObjectInspector)argumentOIs[i]).get(arguments[i].get()));
-          break;
         case DOUBLE:
-          argumentList.add(((DoubleObjectInspector)argumentOIs[i]).get(arguments[i].get()));
-          break;
+        case CHAR:
+        case VARCHAR:
         case STRING:
-          argumentList.add(((StringObjectInspector)argumentOIs[i])
-            .getPrimitiveJavaObject(arguments[i].get()));
-          break;
         case TIMESTAMP:
-          argumentList.add(((TimestampObjectInspector)argumentOIs[i])
+        case DECIMAL:
+          argumentList.add(((PrimitiveObjectInspector)argumentOIs[i])
             .getPrimitiveJavaObject(arguments[i].get()));
           break;
-        case BINARY:
-          argumentList.add(arguments[i].get());
-          break;
         default:
           argumentList.add(arguments[i].get());
           break;
       }
     }
-    formatter.format(pattern, argumentList.toArray());
+    formatter.format(pattern.toString(), argumentList.toArray());
 
     resultText.set(sb.toString());
     return resultText;
